from .text_base import TextBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..smp import *


class TextMCQDataset(TextBaseDataset):
    """Text-only multiple-choice (MCQ) dataset.

    Rows are expected to carry a ``question`` column, optional single-letter
    option columns (``A``-``Z``) and an optional ``hint`` column.
    """

    TYPE = "MCQ"

    # dataset name -> download URL; populated by concrete subclasses.
    DATASET_URL = {}

    # dataset name -> expected file MD5; populated by concrete subclasses.
    DATASET_MD5 = {}

    def build_prompt(self, line):
        """Assemble the text prompt for one dataset row.

        Args:
            line: Either an integer index into ``self.data`` or a row-like
                mapping (e.g. a pandas Series) with ``question``, optional
                ``hint`` and optional capital-letter option columns.

        Returns:
            list[dict]: A single ``{"type": "text", "value": ...}`` message.
        """
        if isinstance(line, int):
            line = self.data.iloc[line]

        question = line["question"]
        # Keep only option columns that exist on this row and are not NaN.
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        options_prompt = "Options:\n"
        for key, item in options.items():
            options_prompt += f"{key}. {item}\n"
        hint = line["hint"] if ("hint" in line and not pd.isna(line["hint"])) else None
        prompt = ""
        if hint is not None:
            prompt += f"Hint: {hint}\n"
        prompt += f"Question: {question}\n"
        if len(options):
            prompt += options_prompt
            prompt += "Please select the correct answer from the options above. \n"

        return [dict(type="text", value=prompt)]

    def evaluate(self, eval_file, **judge_kwargs):
        """Judge a prediction file and write accuracy reports next to it.

        Args:
            eval_file: Path to the prediction file; must contain at least
                ``index``, ``question`` and ``prediction`` columns, and its
                indices must be a subset of ``self.data``.
            **judge_kwargs: Judge configuration. Notable keys: ``model``
                (``"exact_matching"`` [default], ``"chatgpt-0125"`` or
                ``"gpt-4-0125"``) and ``nproc`` (worker count, default 4).

        Returns:
            The accuracy table produced by the reporter, also dumped to
            ``<eval_file stem>_acc.csv``.
        """
        from .utils.multiple_choice import (
            report_acc,
            report_acc_MMT,
            mcq_circular_eval,
            mcq_vanilla_eval,
        )

        # Normalize MMBench TEST split names to their canonical keys.
        dataset_map = {
            "MMBench_TEST_EN": "MMBench",
            "MMBench_TEST_EN_V11": "MMBench_V11",
            "MMBench_TEST_CN": "MMBench_CN",
            "MMBench_TEST_CN_V11": "MMBench_CN_V11",
        }
        dataset = self.dataset_name
        if dataset in dataset_map:
            dataset = dataset_map[dataset]
        nproc = judge_kwargs.pop("nproc", 4)

        # Text MCQ datasets are evaluated without circular augmentation.
        circular = False

        suffix = eval_file.split(".")[-1]
        model = judge_kwargs.get("model", "exact_matching")
        assert model in ["chatgpt-0125", "exact_matching", "gpt-4-0125"]
        name_str_map = {"chatgpt-0125": "openai", "gpt-4-0125": "gpt4"}
        name_str = name_str_map[model] if model in name_str_map else model

        # Fall back to exact matching unless a working LLM judge is available.
        if model == "exact_matching":
            model = None
        elif gpt_key_set():
            model = build_judge(**judge_kwargs)
            if not model.working():
                warnings.warn(
                    "OPENAI API is not working properly, will use exact matching for evaluation"
                )
                warnings.warn(DEBUG_MESSAGE)
                model = None
        else:
            warnings.warn(
                "OPENAI_API_KEY is not set properly, will use exact matching for evaluation"
            )
            model = None

        result_file = eval_file.replace(f".{suffix}", f"_{name_str}_result.pkl")

        data = load(eval_file)
        data = data.sort_values(by="index")
        data["prediction"] = [str(x) for x in data["prediction"]]
        # Lower-case every column name except single-letter option columns
        # (A-Z). Iterate over a snapshot of the keys: data.pop() mutates
        # the frame's columns while we rename them.
        for k in list(data.keys()):
            data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(
                k
            )

        meta = self.data
        meta_q_map = {x: y for x, y in zip(meta["index"], meta["question"])}
        data_map = {x: y for x, y in zip(data["index"], data["question"])}
        for k in data_map:
            assert (
                k in meta_q_map
            ), f"eval_file should be the same as or a subset of dataset {self.dataset_name}"

        if circular:
            data = mcq_circular_eval(
                model, data, meta, nproc, result_file, self.dataset_name
            )
        else:
            data = mcq_vanilla_eval(
                model, data, meta, nproc, result_file, self.dataset_name
            )

        # Persist the per-item judgement, then reload it so the reporting
        # step sees exactly what was written to disk.
        judged_file = eval_file.replace(f".{suffix}", f"_{name_str}_result.{suffix}")
        dump(data, judged_file)
        data = load(judged_file)

        # MMT-style datasets use a dedicated accuracy reporter.
        if "MMT" in dataset:
            acc = report_acc_MMT(data)
        else:
            acc = report_acc(data)

        score_file = eval_file.replace(f".{suffix}", "_acc.csv")
        dump(acc, score_file)

        return acc


class CustomTextMCQDataset(TextMCQDataset):
    """MCQ dataset backed by a user-provided TSV under the LMU data root."""

    def load_data(self, dataset):
        """Load ``<LMUDataRoot>/<dataset>.tsv``, localizing files over 1 GB."""
        tsv_path = osp.join(LMUDataRoot(), f"{dataset}.tsv")

        # Small files are loaded directly; larger ones go through a
        # one-time "_local" conversion that subsequent calls reuse.
        if file_size(tsv_path, "GB") <= 1:
            return load(tsv_path)

        local_path = tsv_path.replace(".tsv", "_local.tsv")
        if not osp.exists(local_path) or os.environ.get("FORCE_LOCAL", None):
            from ..tools import LOCALIZE

            LOCALIZE(tsv_path, local_path)
        return load(local_path)
